package common

import (
	"fmt"
	"go-shop-demo/app/model"
	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
	"log"
	"os"
	"time"
)

var (
	// _db is the package-wide connection handle assigned by InitDB and
	// returned by GetDB.
	_db *gorm.DB
	// _tx is the package-wide transaction lazily started by GetTx and reset
	// to nil by Commit and RollBack; nil means no transaction is open.
	_tx *gorm.DB
)

// InitDB opens the MySQL connection described by Config.Db (see getDSN),
// auto-migrates the application models, configures the connection pool and
// stores the resulting handle in _db for GetDB/GetTx.
//
// It panics (it has no error result) if the connection cannot be opened, if
// the schema migration fails, or if the underlying *sql.DB cannot be
// obtained. Each panic message includes the original error.
func InitDB() {
	dsn := getDSN()
	newLogger := logger.New(
		log.New(os.Stdout, "", log.LstdFlags),
		logger.Config{
			SlowThreshold: time.Second, // queries slower than this are reported as slow
			LogLevel:      logger.Info,
			Colorful:      true,
		},
	)
	db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{Logger: newLogger})
	if err != nil {
		panic(fmt.Sprintf("failed to connect mysql: %v", err))
	}

	// Migrate the schema for all application models. A failed migration would
	// otherwise go unnoticed until queries start failing.
	if err = db.AutoMigrate(&model.User{}, &model.Goods{}, &model.Order{}, &model.OrderGoods{}); err != nil {
		panic(fmt.Sprintf("failed to migrate mysql schema: %v", err))
	}

	sqlDB, err := db.DB()
	if err != nil {
		panic(fmt.Sprintf("failed to get sql.DB from gorm: %v", err))
	}

	// SetMaxIdleConns sets the maximum number of connections kept in the idle pool.
	sqlDB.SetMaxIdleConns(10)

	// SetMaxOpenConns sets the maximum number of open connections to the database.
	sqlDB.SetMaxOpenConns(20)

	// SetConnMaxLifetime sets the maximum amount of time a connection may be reused.
	sqlDB.SetConnMaxLifetime(time.Hour)

	_db = db
}

// getDSN builds the MySQL data source name passed to mysql.Open in InitDB,
// using the credentials and location from Config.Db, in the form
// user:password@tcp(host:port)/database?charset=utf8mb4&parseTime=True&loc=Local.
//
// NOTE(review): the fields are interpolated verbatim (no escaping), and a
// host containing ':' (an IPv6 literal) is not bracketed.
func getDSN() string {
	const dsnFormat = "%v:%v@tcp(%v:%v)/%v?charset=utf8mb4&parseTime=True&loc=Local"
	return fmt.Sprintf(dsnFormat,
		Config.Db.Name, // user name
		Config.Db.Password,
		Config.Db.Host,
		Config.Db.Port,
		Config.Db.Database,
	)
}

// GetDB returns the package-wide *gorm.DB handle stored by InitDB.
// Within this file _db is only assigned by InitDB, so the result is nil
// until InitDB has run.
func GetDB() *gorm.DB {
	handle := _db
	return handle
}

// GetTx returns the package-wide transaction. If none is open, it first
// starts one with _db.Begin(). Every call returns the same transaction until
// Commit or RollBack resets it.
//
// NOTE(review): _tx is a single package-level variable with no locking. If
// callers run concurrently, they share (and race on) one transaction, and one
// caller's Commit/RollBack ends another caller's work. If Begin fails, gorm
// reports that through the returned tx's Error field, and that tx is cached
// like any other until Commit/RollBack. Consider passing a per-request
// transaction explicitly instead.
func GetTx() *gorm.DB {
	if _tx != nil {
		return _tx
	}
	_tx = _db.Begin()
	return _tx
}

// Commit commits the transaction opened by GetTx, if any, and clears it so
// the next GetTx call starts a fresh transaction. It is a no-op when no
// transaction is open.
//
// Commit has no error result. A commit failure is therefore logged instead
// of being silently discarded, as it was before.
func Commit() {
	if _tx == nil {
		return
	}
	tx := _tx
	// Clear first so a failed (or panicking) commit never leaves a stale
	// transaction behind for the next GetTx call.
	_tx = nil
	if err := tx.Commit().Error; err != nil {
		log.Printf("commit transaction: %v", err)
	}
}

// RollBack rolls back the transaction opened by GetTx, if any, and clears it
// so the next GetTx call starts a fresh transaction. It is a no-op when no
// transaction is open.
//
// RollBack has no error result. A rollback failure is therefore logged
// instead of being silently discarded, as it was before.
func RollBack() {
	if _tx == nil {
		return
	}
	tx := _tx
	// Clear first so a failed (or panicking) rollback never leaves a stale
	// transaction behind for the next GetTx call.
	_tx = nil
	if err := tx.Rollback().Error; err != nil {
		log.Printf("rollback transaction: %v", err)
	}
}

func PostProcessTx(err interface{}) {
	if err != nil {
		RollBack()
	} else {
		Commit()
	}
}
